{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "PyTorch_XLA_Debug",
      "provenance": [],
      "collapsed_sections": [],
      "machine_shape": "hm"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "TPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-5-xIyOwIk45",
        "colab_type": "text"
      },
      "source": [
        "Install the PyTorch/XLA release packages (the 1.9 wheel below) and set up the backend version."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "7ktHVEtJr2Qg",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "!pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rW9GanW1r3cC",
        "colab_type": "text"
      },
      "source": [
        "Uncomment and run the cell below only if you want a nightly release."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "sPJVqAKyml5W",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# VERSION = \"nightly\"  #@param [\"1.5\" , \"20200325\", \"nightly\"]\n",
        "# !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py\n",
        "# !python pytorch-xla-env-setup.py --version $VERSION"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AxoPg3MPIz6d",
        "colab_type": "text"
      },
      "source": [
        "Install any other publicly available dependencies (pip, apt, ...)."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "cb8MzLOAKKXS",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# !pip install transformers"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cAZS4UD_I89G",
        "colab_type": "text"
      },
      "source": [
        "Clone the repo containing the model to be tested.\n",
        "\n",
        "If all the code fits in a single code snippet (see the *%%writefile* cell below), you can leave the cell below empty or remove it."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "gyrcRaWFJrBf",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "!rm -rf pytorch-xla-transformer-language-model/\n",
        "!git clone https://github.com/dlibenzi/pytorch-xla-transformer-language-model.git"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bxdfcIREJF97",
        "colab_type": "text"
      },
      "source": [
        "Set up the environment. Uncomment the variables below to enable the corresponding debug options."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2hXaIwi3Kr_1",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import os\n",
        "# os.environ['XLA_IR_DEBUG'] = '1'\n",
        "# os.environ['XLA_HLO_DEBUG'] = '1'\n",
        "# os.environ['TF_CPP_VMODULE'] = 'tensor=5'\n",
        "# os.environ['XLA_USE_32BIT_LONG'] = '1'\n",
        "# os.environ['XLA_SAVE_TENSORS_FILE'] = 'tensors.log'\n",
        "# os.environ['XLA_SAVE_TENSORS_FMT'] = 'text'\n",
        "# os.environ['XLA_TRIM_GRAPH_SIZE'] = '1000000'"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BUVjzXMjJN5Q",
        "colab_type": "text"
      },
      "source": [
        "Override the files that need editing or tweaking during the debug session.\n",
        "\n",
        "This means copying and pasting the content of one or more of the original GitHub repo files into the cell(s) below, so that one can easily iterate while debugging. If the test/debug code does not need to pull any GitHub repo, the cell(s) below simply contain the test code itself.\n",
        "\n",
        "We strongly suggest running on a single core when debugging. If using multi-processing, just pass *nprocs=1* to *xmp.spawn()*.\n",
        "\n",
        "If accuracy debugging is not needed, you can avoid fetching large datasets by using the PyTorch/XLA [data generators](https://github.com/pytorch/xla/blob/dfab0b544c02b5319c3d52bef12cf4487829c182/test/test_train_mp_mnist.py#L61).\n"
      ]
    },
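    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "As a rough sketch of the data generator approach mentioned above, the snippet below builds a synthetic loader with `xu.SampleGenerator`, the helper used by the linked MNIST test. The all-zero `uint8` batch, its shape, and the `sample_count` of 100 are assumptions chosen to match the `train.py` defaults in this notebook; they are not taken from the original repo.\n",
        "\n",
        "```python\n",
        "import torch\n",
        "import torch_xla.utils.utils as xu\n",
        "\n",
        "BATCH_SIZE = 128       # same value as in main()\n",
        "SEQUENCE_LENGTH = 256  # same value as in main()\n",
        "\n",
        "# Assumption: one constant all-zero batch shaped like a LazyDataset batch,\n",
        "# (BATCH_SIZE, SEQUENCE_LENGTH + 1) bytes, yielded sample_count times.\n",
        "train_loader = xu.SampleGenerator(\n",
        "    data=torch.zeros(BATCH_SIZE, SEQUENCE_LENGTH + 1, dtype=torch.uint8),\n",
        "    sample_count=100)  # 100 steps per epoch is an arbitrary choice\n",
        "```\n",
        "\n",
        "One way to use it would be to replace the `get_dataloader(...)` call inside `main()` with this `train_loader`. The loss values are then meaningless, so this only suits debugging of compilation, performance, or hangs."
      ]
    },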
    {
      "cell_type": "code",
      "metadata": {
        "id": "zQ2_OcQxMEI8",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "%%writefile pytorch-xla-transformer-language-model/train.py\n",
        "# Copyright (c) 2019, Bryan McCann\n",
        "# All rights reserved.\n",
        "\n",
        "import os\n",
        "import time\n",
        "import math\n",
        "\n",
        "import numpy\n",
        "import torch\n",
        "import torch.utils.data\n",
        "\n",
        "import torch_xla\n",
        "import torch_xla.debug.metrics as met\n",
        "import torch_xla.distributed.parallel_loader as pl\n",
        "import torch_xla.utils.utils as xu\n",
        "import torch_xla.core.xla_model as xm\n",
        "import torch_xla.distributed.xla_multiprocessing as xmp\n",
        "\n",
        "from transformer import Transformer\n",
        "\n",
        "\n",
        "class LazyDataset:\n",
        "\n",
        "  def __init__(self, path, sequence_length):\n",
        "    self.path = path\n",
        "    self.size = os.stat(path).st_size - sequence_length\n",
        "    self.sequence_length = sequence_length\n",
        "\n",
        "  def __getitem__(self, index):\n",
        "    with open(self.path, 'rb') as f:\n",
        "      f.seek(index)\n",
        "      chunk = f.read(self.sequence_length)\n",
        "    return torch.ByteTensor(numpy.frombuffer(chunk, dtype=numpy.uint8))\n",
        "\n",
        "  def __len__(self):\n",
        "    return self.size\n",
        "\n",
        "\n",
        "def get_dataloader(path, batch_size, sequence_length, num_workers):\n",
        "  dataset = LazyDataset(path, sequence_length + 1)\n",
        "  if xm.xrt_world_size() > 1:\n",
        "    sampler = torch.utils.data.distributed.DistributedSampler(\n",
        "        dataset,\n",
        "        num_replicas=xm.xrt_world_size(),\n",
        "        rank=xm.get_ordinal(),\n",
        "        shuffle=True)\n",
        "  else:\n",
        "    sampler = torch.utils.data.RandomSampler(dataset)\n",
        "  return torch.utils.data.DataLoader(\n",
        "      dataset, sampler=sampler, batch_size=batch_size, num_workers=num_workers)\n",
        "\n",
        "\n",
        "def main(index):\n",
        "  BATCH_SIZE = 128\n",
        "  LOG_STEPS = 10\n",
        "  METRICS_STEP = 50\n",
        "  NUM_EPOCHS = 8\n",
        "  SEQUENCE_LENGTH = 256\n",
        "\n",
        "  device = xm.xla_device()\n",
        "  model = Transformer(256, 12, 512, 2048, 8, 0.2).to(device)\n",
        "  optimizer = torch.optim.SGD(model.parameters(), lr=0.001)\n",
        "\n",
        "  def train_loop_fn(loader):\n",
        "    tracker = xm.RateTracker()\n",
        "\n",
        "    positions = torch.arange(SEQUENCE_LENGTH).long().view(\n",
        "        1, SEQUENCE_LENGTH).to(device)\n",
        "    causal_mask = torch.triu(\n",
        "        torch.ones(\n",
        "            SEQUENCE_LENGTH, SEQUENCE_LENGTH, dtype=torch.uint8, device=device),\n",
        "        diagonal=1).unsqueeze(0)\n",
        "\n",
        "    model.train()\n",
        "    for iteration, batch in enumerate(loader):\n",
        "      input = batch[:, :-1].long()\n",
        "      target = batch[:, 1:].long()\n",
        "\n",
        "      loss = model(input, positions, target, batch_mask=causal_mask)\n",
        "      loss.backward()\n",
        "      xm.optimizer_step(optimizer)\n",
        "\n",
        "      tracker.add(BATCH_SIZE)\n",
        "      if iteration % LOG_STEPS == 0:\n",
        "        print('[{}]({}) Loss={:.5f} Rate={:.2f}'.format(\n",
        "            device, iteration,\n",
        "            loss.item() / math.log(2), tracker.rate()))\n",
        "      if iteration % METRICS_STEP == 0:\n",
        "        xm.master_print(met.metrics_report())\n",
        "\n",
        "  train_loader = get_dataloader('pytorch-xla-transformer-language-model/datasets/enwik8/train/train.txt.raw',\n",
        "                                BATCH_SIZE, SEQUENCE_LENGTH, 0)\n",
        "\n",
        "  for epoch in range(0, NUM_EPOCHS):\n",
        "    para_loader = pl.ParallelLoader(train_loader, [device])\n",
        "    train_loop_fn(para_loader.per_device_loader(device))\n",
        "\n",
        "\n",
        "if __name__ == '__main__':\n",
        "  # Set nprocs=1 for debugging (using one core).\n",
        "  xmp.spawn(main, args=(), nprocs=1)\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lcSHTc_uJYuM",
        "colab_type": "text"
      },
      "source": [
        "Optionally clean up the products of previous runs, since some operations (such as tensor logging) append to existing content."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "46T_FzvTU1KE",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "!rm -f tensors.log\n",
        "!rm -rf /tmp/debug_run*"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "snzW8KikJlvK",
        "colab_type": "text"
      },
      "source": [
        "Run the model's script with the proper command line."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_aceXPFHJ1Zq",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "!python pytorch-xla-transformer-language-model/train.py"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TC-f3okfGQHX",
        "colab_type": "text"
      },
      "source": [
        "For debugging it is also useful to run the *debug_run.py* script to collect a set of debug information packaged in a TAR file.\n",
        "\n",
        "The *debug_run.py* command below should be run for a few steps (around 10 should be enough), or stopped after a given time if the run hangs."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "xq-prSeCGH0J",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "!git clone https://github.com/pytorch/xla.git"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "SFpZLS98HBIq",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "!./xla/scripts/debug_run.py --outfile debug_run.tar.gz --hlo -- python -u pytorch-xla-transformer-language-model/train.py"
      ],
      "execution_count": null,
      "outputs": []
    },
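    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "Before downloading, it can help to check what the archive actually contains. The snippet below is a minimal sketch using only the Python standard library; it assumes the `--outfile` name used in the command above and makes no assumption about which files *debug_run.py* puts in the archive.\n",
        "\n",
        "```python\n",
        "import os\n",
        "import tarfile\n",
        "\n",
        "ARCHIVE = 'debug_run.tar.gz'  # the --outfile value used above\n",
        "\n",
        "if os.path.exists(ARCHIVE):\n",
        "  # List every member of the archive with its size in bytes.\n",
        "  with tarfile.open(ARCHIVE, 'r:gz') as tar:\n",
        "    for member in tar.getmembers():\n",
        "      print('{:>12d}  {}'.format(member.size, member.name))\n",
        "else:\n",
        "  print('{} not found; run debug_run.py first'.format(ARCHIVE))\n",
        "```\n",
        "\n",
        "A missing archive usually means the run was interrupted before the script could package its output."
      ]
    },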
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jHiJZTTIJq5y",
        "colab_type": "text"
      },
      "source": [
        "Download the generated debug files or logs. Uncomment the line for each file you want."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "VCowkX_-Ofkn",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "from google.colab import files\n",
        "# files.download('tensors.log')\n",
        "# files.download('debug_run.tar.gz')"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}
